Projekt z Podstaw Sztucznej Inteligencji

W projekcie wykorzystujemy informacje zebrane podczas gry w Affective SpaceShooter 2.

W pierwszym etapie analizujemy dane pod kątem zależności między wynikiem gry a poszczególnymi cechami osobowości.

W drugiej części dokonujemy uczenia nadzorowanego - na podstawie zbioru cech osobowości staramy się przewidzieć średni wynik punktowy gry reprezentowany przez klasy "low" i "medium".

Preprocessing

In [1]:
import pandas as pd
import numpy as np
import csv


md = pd.read_csv("BIRAFFE-metadata.csv", sep=';')
#usunięcie tych rekordów gdzie osoba nie ma danych z gry w space
md = md[pd.notnull(md['SPACE'])]
md = md[pd.notnull(md['OPENNESS'])]
#zostawienie id tych osób, bo pliki mają w nazwie id
ids = md['ID'].values

import csv
import json    
print("start")
pd.set_option('display.max_columns', 500)
#import metadata
data = pd.read_csv("merged_scores.csv", sep=',')
data = data[pd.notnull(data['OPENNESS'])]
data = data[pd.notnull(data['CONSCIENTIOUSNESS'])]
data = data[pd.notnull(data['EXTRAVERSION'])]
data = data[pd.notnull(data['AGREEABLENESS'])]
data = data[pd.notnull(data['NEUROTICISM'])]

data=data[data.Score == 'GameOver']
#data.head(25)
type(data)
mean_c=[]

with open('mean_scores.csv', 'w', newline='') as csvfile:
    #nazwy kolumn- wszystkie z plików json
    fieldnames = ["P_ID","OPENNESS","CONSCIENTIOUSNESS","EXTRAVERSION","AGREEABLENESS","NEUROTICISM","Mean"]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    #ustawienie nagłówków
    writer.writeheader()
    for my_id in ids:
        data1=data.loc[data['P_ID'] == my_id]
        nr=data1.shape[0]
        #print(data1)
        score=data1['Value'].sum()
        if (score):
            mean = score/nr
            #print(mean)
            new_data={ 'P_ID': data1['P_ID'].iloc[0], 'OPENNESS': data1['OPENNESS'].iloc[0], 'CONSCIENTIOUSNESS': data1['CONSCIENTIOUSNESS'].iloc[0],'EXTRAVERSION':  data1['EXTRAVERSION'].iloc[0],'AGREEABLENESS': data1['AGREEABLENESS'].iloc[0],'NEUROTICISM': data1['NEUROTICISM'].iloc[0],'Mean': mean}
            writer.writerow(new_data)
start

Obliczamy średni wynik gry dla każdej osoby, który będziemy zestawiać z cechami osobowości.

Analiza danych

In [2]:
import matplotlib.pyplot as plt
import plotly.express as px
import seaborn as sns
import pandas as pd
import numpy as np
import csv

pd.set_option('display.max_columns', 500)
#import metadata
data = pd.read_csv("mean_scores.csv", sep=',')
data.head()
fig = px.scatter(data, x = 'OPENNESS', y = 'Mean', title='test')
fig.show()
fig1 = px.scatter(data, x = 'CONSCIENTIOUSNESS', y = 'Mean', title='test')
fig1.show()
fig2 = px.scatter(data, x = 'EXTRAVERSION', y = 'Mean', title='test')
fig2.show()
fig3 = px.scatter(data, x = 'AGREEABLENESS', y = 'Mean', title='test')
fig3.show()
fig4 = px.scatter(data, x = 'NEUROTICISM', y = 'Mean', title='test')
fig4.show()

Powyższe wykresy przedstawiają korelację średniego wyniku gry z każdą z cech osobowości. Można z nich wywnioskować, że zależności nie istnieją.

In [3]:
import matplotlib.pyplot as plt
import seaborn as sns
fig, ax = plt.subplots(figsize=(5,5))
sns.heatmap(data.corr(), vmax=1.0, center=0, fmt='.2f', linewidths=.9, annot=True,cbar_kws={"shrink": .70})
plt.show();

Powyższe zestawienie również pokazuje, że korelacja między badanymi elementami jest niska.

In [4]:
data.boxplot(column=['OPENNESS', 'CONSCIENTIOUSNESS', 'EXTRAVERSION', 'AGREEABLENESS', 'NEUROTICISM'], rot=45)
Out[4]:
<matplotlib.axes._subplots.AxesSubplot at 0x2a5ff844f60>
In [5]:
data.boxplot(column=['Mean'], rot=45)
print("Większość wyników znajduje się pomiędzy około 400 a 1100 punktów, grupa o średniej <=2000 to najbardziej wiarygodna grupa testowa")
Większość wyników znajduje się pomiędzy około 400 a 1100 punktów, grupa o średniej <=2000 to najbardziej wiarygodna grupa testowa
In [6]:
data_trimmed=data.loc[data['Mean'] <= 2000]

fig, ax = plt.subplots(figsize=(5,5))
sns.heatmap(data_trimmed.corr(), vmax=1.0, center=0, fmt='.2f', linewidths=.9, annot=True,cbar_kws={"shrink": .70})
plt.show();

Zestawienie korelacji dla danych testowych, w których średni wynik gry wyniósł <= 2000 punktów. Korelacja między badanymi elementami nadal jest niska.

In [7]:
pd.set_option('display.max_columns', 500)
#import metadata
data = pd.read_csv("mean_scores.csv", sep=',')
data.head()
fig = px.scatter(data_trimmed, x = 'OPENNESS', y = 'Mean', title='test')
fig.show()
fig1 = px.scatter(data_trimmed, x = 'CONSCIENTIOUSNESS', y = 'Mean', title='test')
fig1.show()
fig2 = px.scatter(data_trimmed, x = 'EXTRAVERSION', y = 'Mean', title='test')
fig2.show()
fig3 = px.scatter(data_trimmed, x = 'AGREEABLENESS', y = 'Mean', title='test')
fig3.show()
fig4 = px.scatter(data_trimmed, x = 'NEUROTICISM', y = 'Mean', title='test')
fig4.show()

Niską korelację potwierdzają również wykresy.

In [8]:
from sklearn.model_selection import train_test_split 
from sklearn import metrics
from sklearn import linear_model as ln


def showBarPlot(X, Y, title):
    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2)

    regressor = ln.LinearRegression()  
    regressor.fit(X_train, y_train) #training the algorithm
    print(regressor.score(X_train, y_train))


    #To retrieve the intercept:
    print(regressor.intercept_)
    #For retrieving the slope:
    print(regressor.coef_)


    y_pred = regressor.predict(X_test)

    df1 = pd.DataFrame({'Actual': y_test, 'Predicted': y_pred})
    df2 = df1.head(25)


    df2.plot(kind='bar', figsize=(10,5))
    plt.grid(which='major', linestyle='-', linewidth='0.5', color='green')
    plt.grid(which='minor', linestyle=':', linewidth='0.5', color='black')
    plt.show()
    
    print(title)
    print('Mean Absolute Error:', metrics.mean_absolute_error(y_test, y_pred))  
    print('Mean Squared Error:', metrics.mean_squared_error(y_test, y_pred))  
    print('Root Mean Squared Error:', np.sqrt(metrics.mean_squared_error(y_test, y_pred)))
    print('\n\n')

    
    
X = data[['OPENNESS']].values
Y = data['Mean'].values #to predict
showBarPlot(X,Y,'OPENNESS - Mean')

X = data[['CONSCIENTIOUSNESS']].values
Y = data['Mean'].values #to predict
showBarPlot(X,Y,'CONSCIENTIOUSNESS - Mean')

X = data[['EXTRAVERSION']].values
Y = data['Mean'].values #to predict
showBarPlot(X,Y,'EXTRAVERSION - Mean')

X = data[['AGREEABLENESS']].values
Y = data['Mean'].values #to predict
showBarPlot(X,Y,'AGREEABLENESS - Mean')

X = data[['NEUROTICISM']].values
Y = data['Mean'].values #to predict
showBarPlot(X,Y,'NEUROTICISM - Mean')
0.005915833051030495
692.4587986116999
[29.64633661]
OPENNESS - Mean
Mean Absolute Error: 642.5904749032688
Mean Squared Error: 886079.810546058
Root Mean Squared Error: 941.3181239868156



0.0006520654704598172
954.5814619417672
[-9.43169635]
CONSCIENTIOUSNESS - Mean
Mean Absolute Error: 566.2891284876794
Mean Squared Error: 552046.8220764056
Root Mean Squared Error: 742.9985343702945



0.011503300158217433
702.1916133632059
[37.32208086]
EXTRAVERSION - Mean
Mean Absolute Error: 585.7911173504766
Mean Squared Error: 636138.153756304
Root Mean Squared Error: 797.5826939924813



0.02561676962008297
524.4197620362361
[47.00775154]
AGREEABLENESS - Mean
Mean Absolute Error: 706.6173157113082
Mean Squared Error: 1205639.4220720944
Root Mean Squared Error: 1098.0161301511441



0.0009316610370181388
804.6776484138129
[9.14616564]
NEUROTICISM - Mean
Mean Absolute Error: 516.6905561369783
Mean Squared Error: 718371.3364571542
Root Mean Squared Error: 847.5678948952433



Niska korelacja wpływa na złe wyniki modelu uczonego za pomocą regresji liniowej, który szuka zależności między średnim wynikiem a poszczególnymi cechami osobowości.

In [9]:
X = data[['OPENNESS']].values
Y = data['CONSCIENTIOUSNESS'].values #to predict
showBarPlot(X,Y,'OPENNESS - CONSCIENTIOUSNESS')

X = data[['OPENNESS']].values
Y = data['EXTRAVERSION'].values #to predict
showBarPlot(X,Y,'OPENNESS - EXTRAVERSION')

X = data[['OPENNESS']].values
Y = data['AGREEABLENESS'].values #to predict
showBarPlot(X,Y,'OPENNESS - AGREEABLENESS')

X = data[['OPENNESS']].values
Y = data['NEUROTICISM'].values #to predict
showBarPlot(X,Y,'OPENNESS - NEUROTICISM')
0.011414851397050452
4.675068399452805
[0.12199042]
OPENNESS - CONSCIENTIOUSNESS
Mean Absolute Error: 1.62499658002736
Mean Squared Error: 4.30605297039361
Root Mean Squared Error: 2.075103122833564



0.05293899767934939
3.631852446966719
[0.27037145]
OPENNESS - EXTRAVERSION
Mean Absolute Error: 1.5960980723976226
Mean Squared Error: 4.426218597179607
Root Mean Squared Error: 2.1038580268591334



4.646700067667541e-05
5.8912513288724915
[0.00759802]
OPENNESS - AGREEABLENESS
Mean Absolute Error: 2.1445406791320116
Mean Squared Error: 6.541150129932005
Root Mean Squared Error: 2.557567228819607



0.00910168320824345
4.925903794442785
[0.13654019]
OPENNESS - NEUROTICISM
Mean Absolute Error: 2.435892839358629
Mean Squared Error: 8.079287684163887
Root Mean Squared Error: 2.8424087820304607



In [10]:
X = data[['CONSCIENTIOUSNESS']].values
Y = data['OPENNESS'].values #to predict
showBarPlot(X,Y,'CONSCIENTIOUSNESS - OPENNESS')

X = data[['CONSCIENTIOUSNESS']].values
Y = data['EXTRAVERSION'].values #to predict
showBarPlot(X,Y,'CONSCIENTIOUSNESS - EXTRAVERSION')

X = data[['CONSCIENTIOUSNESS']].values
Y = data['AGREEABLENESS'].values #to predict
showBarPlot(X,Y,'CONSCIENTIOUSNESS - AGREEABLENESS')

X = data[['CONSCIENTIOUSNESS']].values
Y = data['NEUROTICISM'].values #to predict
showBarPlot(X,Y,'CONSCIENTIOUSNESS - NEUROTICISM')
0.02504809148721676
4.485889446217763
[0.13712824]
CONSCIENTIOUSNESS - OPENNESS
Mean Absolute Error: 1.4859071794092316
Mean Squared Error: 3.15056138700598
Root Mean Squared Error: 1.7749820807563044



0.041640207116332095
4.215893513714337
[0.19803477]
CONSCIENTIOUSNESS - EXTRAVERSION
Mean Absolute Error: 2.1825465151534558
Mean Squared Error: 6.429483604814885
Root Mean Squared Error: 2.535642641385983



0.008505349670295925
5.345655950779198
[0.09542409]
CONSCIENTIOUSNESS - AGREEABLENESS
Mean Absolute Error: 2.0457343278512927
Mean Squared Error: 5.957146567051942
Root Mean Squared Error: 2.440726647343357



0.09119638680826414
7.4111995080518085
[-0.35827664]
CONSCIENTIOUSNESS - NEUROTICISM
Mean Absolute Error: 2.3999218519799634
Mean Squared Error: 8.27316419299563
Root Mean Squared Error: 2.876310865152727



In [11]:
X = data[['EXTRAVERSION']].values
Y = data['OPENNESS'].values #to predict
showBarPlot(X,Y,'EXTRAVERSION - OPENNESS')

X = data[['EXTRAVERSION']].values
Y = data['CONSCIENTIOUSNESS'].values #to predict
showBarPlot(X,Y,'EXTRAVERSION - CONSCIENTIOUSNESS')

X = data[['EXTRAVERSION']].values
Y = data['AGREEABLENESS'].values #to predict
showBarPlot(X,Y,'EXTRAVERSION - AGREEABLENESS')

X = data[['EXTRAVERSION']].values
Y = data['NEUROTICISM'].values #to predict
showBarPlot(X,Y,'EXTRAVERSION - NEUROTICISM')
0.057214981389150044
4.174646354733406
[0.21422742]
EXTRAVERSION - OPENNESS
Mean Absolute Error: 1.5493969894813204
Mean Squared Error: 3.301636088527057
Root Mean Squared Error: 1.8170404752033062



0.038122461437903876
4.552544613350959
[0.19277823]
EXTRAVERSION - CONSCIENTIOUSNESS
Mean Absolute Error: 1.9470651181821041
Mean Squared Error: 5.550776132551586
Root Mean Squared Error: 2.3560085170795935



0.029031032393312395
5.015373352855051
[0.16913303]
EXTRAVERSION - AGREEABLENESS
Mean Absolute Error: 2.4918191800878486
Mean Squared Error: 7.866977600527412
Root Mean Squared Error: 2.8048132915628115



0.08077792009993955
7.306624838531094
[-0.33040536]
EXTRAVERSION - NEUROTICISM
Mean Absolute Error: 2.2027565561501716
Mean Squared Error: 6.904303942077921
Root Mean Squared Error: 2.6276042209735317



In [12]:
X = data[['AGREEABLENESS']].values
Y = data['OPENNESS'].values #to predict
showBarPlot(X,Y,'AGREEABLENESS - OPENNESS')

X = data[['AGREEABLENESS']].values
Y = data['CONSCIENTIOUSNESS'].values #to predict
showBarPlot(X,Y,'AGREEABLENESS - CONSCIENTIOUSNESS')

X = data[['AGREEABLENESS']].values
Y = data['EXTRAVERSION'].values #to predict
showBarPlot(X,Y,'AGREEABLENESS - EXTRAVERSION')

X = data[['AGREEABLENESS']].values
Y = data['NEUROTICISM'].values #to predict
showBarPlot(X,Y,'AGREEABLENESS - NEUROTICISM')
0.00023510013818428543
5.441174819521944
[-0.01183352]
AGREEABLENESS - OPENNESS
Mean Absolute Error: 1.8387291500596916
Mean Squared Error: 5.0511903532463265
Root Mean Squared Error: 2.247485339940247



0.006865224920823643
4.978177947307726
[0.07562317]
AGREEABLENESS - CONSCIENTIOUSNESS
Mean Absolute Error: 1.8191608267541914
Mean Squared Error: 4.920463404743931
Root Mean Squared Error: 2.218211758318833



0.018223337242931126
4.449412798284261
[0.13493366]
AGREEABLENESS - EXTRAVERSION
Mean Absolute Error: 1.9755855867495253
Mean Squared Error: 5.386411573746121
Root Mean Squared Error: 2.3208644022747476



0.0007684475886908793
5.793045243832332
[-0.03187495]
AGREEABLENESS - NEUROTICISM
Mean Absolute Error: 2.6014738799124237
Mean Squared Error: 8.976412472412825
Root Mean Squared Error: 2.996066166227446



In [13]:
X = data[['NEUROTICISM']].values
Y = data['OPENNESS'].values #to predict
showBarPlot(X,Y,'NEUROTICISM - OPENNESS')

X = data[['NEUROTICISM']].values
Y = data['CONSCIENTIOUSNESS'].values #to predict
showBarPlot(X,Y,'NEUROTICISM - CONSCIENTIOUSNESS')

X = data[['NEUROTICISM']].values
Y = data['EXTRAVERSION'].values #to predict
showBarPlot(X,Y,'NEUROTICISM - EXTRAVERSION')

X = data[['NEUROTICISM']].values
Y = data['AGREEABLENESS'].values #to predict
showBarPlot(X,Y,'NEUROTICISM - AGREEABLENESS')
0.03528545083691803
4.426972085143993
[0.14032923]
NEUROTICISM - OPENNESS
Mean Absolute Error: 1.4237699786403473
Mean Squared Error: 3.1513077557116147
Root Mean Squared Error: 1.7751923151342264



0.09054126909121041
6.887755840023143
[-0.24575107]
NEUROTICISM - CONSCIENTIOUSNESS
Mean Absolute Error: 1.8288252452930256
Mean Squared Error: 5.1345884481906605
Root Mean Squared Error: 2.2659630288666803



0.09157319412041909
6.707381597052688
[-0.25330019]
NEUROTICISM - EXTRAVERSION
Mean Absolute Error: 1.6605252489384053
Mean Squared Error: 4.215511926076591
Root Mean Squared Error: 2.053171187718304



0.0004343681456815407
6.097383538619773
[-0.01776083]
NEUROTICISM - AGREEABLENESS
Mean Absolute Error: 1.610656498507388
Mean Squared Error: 4.47897064558892
Root Mean Squared Error: 2.1163578727589814



Niska korelacja wpływa również na złe wyniki modelu uczonego za pomocą regresji liniowej, który szuka zależności między poszczególnymi cechami osobowości.

Uczenie nadzorowane

Z analizy wynika, że nie ma szczególnych zależności między pojedynczymi cechami osobowości a wynikiem gry.

Rozpatrzymy zatem cały zbiór cech i dokonamy na nich uczenia nadzorowanego - predykcji wyniku gry w klasach "low" i "medium".

Dokonujemy porównania różnych modeli.

In [18]:
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier,NearestCentroid
from sklearn.svm import SVC
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.naive_bayes import GaussianNB, MultinomialNB
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from sklearn import linear_model
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn import preprocessing
from sklearn.metrics import classification_report, f1_score
import random 

data = pd.read_csv("mean_scores.csv", sep=',')
data.head()

h = .02  # step size in the mesh


def mapping(x):
    if x < data['Mean'].median() :
        return 'low'
    else:
        return 'high'
            
data['Mean_Class'] = data['Mean'].map(lambda x: mapping(x));


X = data[['OPENNESS', 'CONSCIENTIOUSNESS', 'EXTRAVERSION', 'AGREEABLENESS', 'NEUROTICISM']].values

y = data['Mean_Class'].values #to predict

names = ["Nearest Centroid ", #przypisuje do obserwacji etykietę klasy próbek treningowych, których średnia jest najbliższa obserwacji
         "Nearest Neighbors", #zależność między zmiennymi objaśniającymi a objaśnianymi jest złożona lub nietypowa
         "Linear SVC", #Support Vector Classification  
         "Gaussian Process", #bazuje na aproksymacji Laplace'a - dwie klasy
         "Decision Tree", 
         "Random Forest", 
         "Naive Bayes",#Naiwne klasyfikatory bayesowskie są oparte na założeniu o wzajemnej niezależności predyktorów
         "Logistic Regression"] #zmienna zależna przyjmuje tylko dwie wartości

classifiers = [
    NearestCentroid(),
    KNeighborsClassifier(3),
    SVC(kernel="linear", C=0.025),
    GaussianProcessClassifier(1.0 * RBF(1.0)),
    DecisionTreeClassifier(max_depth=5),
    RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
    GaussianNB(),
    linear_model.LogisticRegression(solver='lbfgs')]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42, shuffle=True)

#randomowy przydzial
y_pred = np.empty((y_test.size),dtype = '<U6' )
for i in range (0,y_test.size):
    y_pred[i] = random.choice(['low','high'])
print(classification_report(y_test, y_pred))
print("Średni wynik randomowego przydziału oscyluje ok.0.5")
print("")
      
for name, classifier in zip(names, classifiers):
    classifier.fit(X_train, y_train)
    # wyniki walidacji krzyzowej dla danego estymatora - określają jakość modelu
    # Testowanie każdego podzbioru używając pozostałych jako zbiór treningowy.
    # cv - określa krotność walidacji - przy niedużych zbiorach najczęściej k=10
    scores = cross_val_score(classifier, X, y, cv=10).tolist()
    print(name)
    print("Walidacja krzyżowa:", np.mean(scores))
    predictions = classifier.predict(X_test)
    print(classification_report(y_test, predictions))
    print("")
              precision    recall  f1-score   support

        high       0.52      0.50      0.51        26
         low       0.46      0.48      0.47        23

    accuracy                           0.49        49
   macro avg       0.49      0.49      0.49        49
weighted avg       0.49      0.49      0.49        49

Średni wynik randomowego przydziału oscyluje ok.0.5

Nearest Centroid 
Walidacja krzyżowa: 0.5946428571428573
              precision    recall  f1-score   support

        high       0.63      0.65      0.64        26
         low       0.59      0.57      0.58        23

    accuracy                           0.61        49
   macro avg       0.61      0.61      0.61        49
weighted avg       0.61      0.61      0.61        49


Nearest Neighbors
Walidacja krzyżowa: 0.5741071428571428
              precision    recall  f1-score   support

        high       0.71      0.65      0.68        26
         low       0.64      0.70      0.67        23

    accuracy                           0.67        49
   macro avg       0.67      0.67      0.67        49
weighted avg       0.68      0.67      0.67        49


Linear SVC
Walidacja krzyżowa: 0.6160714285714286
              precision    recall  f1-score   support

        high       0.65      0.50      0.57        26
         low       0.55      0.70      0.62        23

    accuracy                           0.59        49
   macro avg       0.60      0.60      0.59        49
weighted avg       0.60      0.59      0.59        49


Gaussian Process
Walidacja krzyżowa: 0.59375
              precision    recall  f1-score   support

        high       0.63      0.46      0.53        26
         low       0.53      0.70      0.60        23

    accuracy                           0.57        49
   macro avg       0.58      0.58      0.57        49
weighted avg       0.59      0.57      0.57        49


Decision Tree
Walidacja krzyżowa: 0.6187500000000001
              precision    recall  f1-score   support

        high       0.54      0.54      0.54        26
         low       0.48      0.48      0.48        23

    accuracy                           0.51        49
   macro avg       0.51      0.51      0.51        49
weighted avg       0.51      0.51      0.51        49


Random Forest
Walidacja krzyżowa: 0.6241071428571427
              precision    recall  f1-score   support

        high       0.65      0.58      0.61        26
         low       0.58      0.65      0.61        23

    accuracy                           0.61        49
   macro avg       0.61      0.61      0.61        49
weighted avg       0.62      0.61      0.61        49


Naive Bayes
Walidacja krzyżowa: 0.5205357142857142
              precision    recall  f1-score   support

        high       0.67      0.46      0.55        26
         low       0.55      0.74      0.63        23

    accuracy                           0.59        49
   macro avg       0.61      0.60      0.59        49
weighted avg       0.61      0.59      0.58        49


Logistic Regression
Walidacja krzyżowa: 0.6276785714285714
              precision    recall  f1-score   support

        high       0.64      0.54      0.58        26
         low       0.56      0.65      0.60        23

    accuracy                           0.59        49
   macro avg       0.60      0.60      0.59        49
weighted avg       0.60      0.59      0.59        49


Wyniki walidacji krzyżowej przekraczają 0.5, zatem można uznać, że modele całkiem dobrze radzą sobie z predykcją. Najlepsze wyniki osiągają modele: "Linear SVM", "Decision Tree" oraz "Logistic Regression".

F₁ jest kolejną miarą dokładności testu. W tym przypadku najlepsze wyniki osiągają modele: "Nearest Neighbors" oraz "Nearest Centroid".

Biorąc pod uwagę średnią obu wartości, najlepszym modelem okazuje się "Nearest Neighbors".

In [19]:
print("Porównanie wyników przy zmianie liczby sąsiadów:")
xx =[]
yy = []
for i in [1,2,3,5,10,15,20,30,40,50]:
    classifier = KNeighborsClassifier(i)
    classifier.fit(X_train, y_train)
    scores = cross_val_score(classifier, X, y, cv=10).tolist()
    predictions = classifier.predict(X_test)
    xx.append(np.mean(scores))
    yy.append(np.mean(f1_score(y_test, predictions, average=None)))

labels = [1,2,3,5,10,15,20,30,40,50]
def plot_bar_x():
    index = np.arange(len(labels))
    plt.bar(index, xx)
    plt.xlabel('Liczba sąsiadów')
    plt.xticks(index, labels, fontsize=10)
    plt.show()
def plot_bar_y():
    index = np.arange(len(labels))
    plt.bar(index, yy)
    plt.xlabel('Liczba sąsiadów')
    plt.xticks(index, labels, fontsize=10)
    plt.show()

print("")
print("Zmiana liczby sąsiadów a wartość walidacji krzyżowej:")
plot_bar_x()
print("Zmiana liczby sąsiadów a wartość F1:")
plot_bar_y()
Porównanie wyników przy zmianie liczby sąsiadów:

Zmiana liczby sąsiadów a wartość walidacji krzyżowej:
Zmiana liczby sąsiadów a wartość F1:

Z wykresów wynika, że najoptymalniejsza liczba sąsiadów znajduje się w przedziale <2,10>, ponieważ w obu przypadkach jakość modelu jest na wysokim poziomie.

In [ ]: